// Bring the crate's shared `include_models!` macro into scope and invoke it for
// `group_norm`; this makes a `group_norm` module available in this file (the test
// below uses it as `group_norm::Model`).
use crate::include_models;
include_models!(group_norm);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Tensor, TensorData, Tolerance, ops::FloatElem};

    use crate::backend::TestBackend;

    /// Float element type of the test backend; used as the precision for the
    /// approximate comparison.
    type FT = FloatElem<TestBackend>;

    /// Runs the `group_norm` model on a fixed rank-4 input of shape `[2, 6, 2, 3]`
    /// and checks the output element-wise against hard-coded reference values,
    /// using the default tolerance.
    ///
    /// NOTE(review): the reference values are presumably taken from the framework
    /// the model was exported from — TODO confirm their provenance.
    #[test]
    fn group_norm() {
        // Input fixture: 2 batch items, each with 6 channels of 2 x 3 values.
        let input_values = [
            [
                [[0.5, -0.14, 0.65], [1.52, -0.23, -0.23]],
                [[1.58, 0.77, -0.47], [0.54, -0.46, -0.47]],
                [[0.24, -1.91, -1.72], [-0.56, -1.01, 0.31]],
                [[-0.91, -1.41, 1.47], [-0.23, 0.07, -1.42]],
                [[-0.54, 0.11, -1.15], [0.38, -0.6, -0.29]],
                [[-0.6, 1.85, -0.01], [-1.06, 0.82, -1.22]],
            ],
            [
                [[0.21, -1.96, -1.33], [0.2, 0.74, 0.17]],
                [[-0.12, -0.3, -1.48], [-0.72, -0.46, 1.06]],
                [[0.34, -1.76, 0.32], [-0.39, -0.68, 0.61]],
                [[1.03, 0.93, -0.84], [-0.31, 0.33, 0.98]],
                [[-0.48, -0.19, -1.11], [-1.2, 0.81, 1.36]],
                [[-0.07, 1., 0.36], [-0.65, 0.36, 1.54]],
            ],
        ];

        // Reference output, laid out exactly like the input ([2, 6, 2, 3]).
        // The trailing `//` markers keep rustfmt from merging the two rows of each channel.
        let reference = TensorData::from([
            [
                [
                    [0.25154344, -0.44127524, 0.41392279], //
                    [1.3557232, -0.53870287, -0.53870287],
                ],
                [
                    [1.52620627, 0.96459487, 0.10484407], //
                    [0.80512497, 0.11177749, 0.10484407],
                ],
                [
                    [0.76817466, -0.40737972, -0.30349366], //
                    [0.33075904, 0.08471276, 0.80644851],
                ],
                [
                    [0.36546649, 0.14186627, 1.42980372], //
                    [0.66956283, 0.80372297, 0.13739426],
                ],
                [
                    [0.89727329, 0.91652675, 0.87920467], //
                    [0.92452434, 0.89549605, 0.90467847],
                ],
                [
                    [0.19805929, 0.50608551, 0.27223703], //
                    [0.14022581, 0.37658877, 0.1201098],
                ],
            ],
            [
                [
                    [0.51190765, -1.41000622, -0.85203131], //
                    [0.50305091, 0.98131516, 0.47648066],
                ],
                [
                    [0.7569541, 0.65484651, -0.01452549], //
                    [0.41659547, 0.56408421, 1.42632599],
                ],
                [
                    [0.50130133, -0.83705754, 0.48855505], //
                    [0.03606231, -0.14875872, 0.67337606],
                ],
                [
                    [1.0211393, 0.9690137, 0.04639013], //
                    [0.32265593, 0.65625994, 0.99507652],
                ],
                [
                    [0.88949868, 0.89789333, 0.87126201], //
                    [0.86865676, 0.92684043, 0.94276133],
                ],
                [
                    [0.22297845, 0.35444493, 0.27581078], //
                    [0.15171624, 0.27581078, 0.42079251],
                ],
            ],
        ]);

        let device = Default::default();
        let model: group_norm::Model<TestBackend> = group_norm::Model::default();
        let x = Tensor::<TestBackend, 4>::from_floats(input_values, &device);

        // Run the model, then compare its materialized output with the reference.
        let actual = model.forward(x).to_data();
        actual.assert_approx_eq::<FT>(&reference, Tolerance::default());
    }
}
